import os
import json
import re
import statistics
import numpy as np
from utils.dataset_utils import get_dataset
from utils import grader
import random
from collections import defaultdict
import itertools
import math

def print_knn_permutation(model, method, k, accuracy_dict):
    """Print each sampled ordering of ``range(k)`` with its stored accuracy.

    Both NumPy's and ``random``'s global generators are reseeded with 0.
    ``np.random.permutation(k)`` is then drawn repeatedly, keeping an
    ordering the first time it appears, until the target count is reached:
    10 when ``k > 3``, otherwise ``k!``. For the ordering at position ``i``
    the function prints the ordering followed by
    ``accuracy_dict[model]['id'][method][k][i]`` with four decimals.

    Args:
        model: First-level key into ``accuracy_dict``.
        method: Key below the ``'id'`` entry of ``accuracy_dict[model]``.
        k: Permutation length; also used as a key into ``accuracy_dict``.
        accuracy_dict: Nested mapping whose innermost value is indexable by
            the position of each printed ordering.
    """
    fixed_seed = 0
    np.random.seed(fixed_seed)
    random.seed(fixed_seed)

    target = 10 if k > 3 else math.factorial(k)

    orderings = []
    while len(orderings) < target:
        candidate = np.random.permutation(k).tolist()
        if candidate not in orderings:
            orderings.append(candidate)

    for idx, ordering in enumerate(orderings):
        print(f"permutation {idx}: {ordering}")
        print(f"{accuracy_dict[model]['id'][method][k][idx]:.4f}")
    
def extract_answer_gold(text):
    """Return the reference answer that follows a ``####`` marker.

    The marker must be followed by exactly one whitespace character;
    everything after it (newlines included) is captured and stripped.

    Args:
        text: Reference solution text.

    Returns:
        The stripped answer, the string ``"None"`` when the answer is
        "none" in any casing, or ``None`` when no marker is found.
    """
    found = re.search(r"####\s(.*)", text, re.DOTALL)
    if found is None:
        return None

    candidate = found.group(1).strip()
    return "None" if candidate.lower() == "none" else candidate

def _normalize_none_answer(answer):
    """Strip ``answer`` and map any casing of "none" to the string "None"."""
    answer = answer.strip()
    return "None" if answer.lower() == "none" else answer


def _balanced_brace_content(text, start):
    """Return ``text[start:end]`` where ``end`` closes an already-open brace.

    ``start`` is the index just after an opening ``{``. Scanning stops at a
    newline so the result never spans lines. Returns ``None`` when no
    balancing ``}`` is found before a newline or the end of ``text``.
    """
    depth = 1
    for idx in range(start, len(text)):
        ch = text[idx]
        if ch == '\n':
            return None
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return text[start:idx]
    return None


def extract_answer_pred(text,file_dir,i, model):
    """Extract the final answer from a model response.

    Two formats are tried in order:

    1. A ``####`` marker: the answer is captured lazily up to the next
       ``\\n####``, a blank line (``\\n\\n``), or the end of the text.
    2. A ``\\boxed{...}`` expression on a single line. Nested braces are
       balanced, so ``\\boxed{\\frac{1}{2}}`` yields ``\\frac{1}{2}``. If
       the braces never balance on that line, the text up to the first
       ``}`` is used instead.

    Args:
        text: Model response to search.
        file_dir: Unused; accepted so existing call sites keep working.
        i: Unused; accepted so existing call sites keep working.
        model: Unused; accepted so existing call sites keep working.

    Returns:
        The stripped answer (possibly an empty string when the marker is
        followed by nothing), the string ``"None"`` when the answer is
        "none" in any casing, or ``None`` when neither format is present.
    """
    marker_match = re.search(r"####\s*(.*?)(?:\n####|\n\n|$)", text, re.DOTALL)
    if marker_match:
        return _normalize_none_answer(marker_match.group(1))

    boxed_match = re.search(r"\\boxed\{(.*?)\}", text)
    if boxed_match:
        # The lazy regex stops at the first '}', which truncates nested
        # LaTeX groups; re-scan from the same position with brace counting.
        answer = _balanced_brace_content(text, boxed_match.start(1))
        if answer is None:
            answer = boxed_match.group(1)
        return _normalize_none_answer(answer)

    return None

def count_newlines(text):
    """Return the number of line-feed characters contained in ``text``."""
    return sum(ch == '\n' for ch in text)

def read_jsonl(file_dir, key='question'):
    """Collect one field from every record of a JSON Lines file.

    Args:
        file_dir: Path of the UTF-8 encoded file to read.
        key: Field to pull out of each decoded record.

    Returns:
        List of ``record[key]`` values in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a decoded record has no ``key`` entry.
    """
    results = []
    with open(file_dir, 'r', encoding='utf-8') as file:
        for line in file:
            # Whitespace-only lines (e.g. a trailing empty line) are not
            # records; decoding them would abort the whole read.
            if not line.strip():
                continue
            results.append(json.loads(line)[key])
    return results

def extract_answer_pred_old(text):
    """Legacy numeric extractor for responses that use a ``####`` marker.

    When the marker is followed (lazily) by a blank line, the text in
    between is passed to ``extract_digits``. Otherwise the first number-like
    token after the marker (optional minus sign, digits, commas and dots)
    is returned as written.

    Args:
        text: Model response to search.

    Returns:
        The extracted string, or ``None`` when there is no marker or no
        number after it.
    """
    terminated = re.search(r"####\s(.*?)\n\n", text, re.DOTALL)
    if terminated:
        return extract_digits(terminated.group(1).strip())

    trailing = re.search(r"####\s(.*)", text, re.DOTALL)
    if not trailing:
        return None

    number = re.search(r"-?\d[\d,.]*", trailing.group(1))
    return number.group(0).strip() if number else None

def extract_digits(text):
    """Concatenate every number found in ``text`` into one string.

    Commas are removed first. Each run of digits, optionally with a decimal
    part, is collected; a leading minus sign is matched but is not part of
    the captured group, so it does not appear in the result.

    Args:
        text: String to scan.

    Returns:
        The collected numbers joined without a separator, or ``""`` when
        none are found.
    """
    cleaned = text.replace(',', '')
    number_re = re.compile(r'(?<!\d)-?(\d+(\.\d+)?)')
    pieces = [hit.group(1) for hit in number_re.finditer(cleaned)]
    return ''.join(pieces)

def calculate_accuracy(gold_answers, pred_answers,file_dir, test_dataset, model):
    """Grade predictions against reference answers.

    Each reference is parsed with ``extract_answer_gold`` and each prediction
    with ``extract_answer_pred``; commas are removed from both before they
    are compared through ``grader.grade_answer``. For ``"gsm8k-plus-mini"``,
    pairs where either side is the ``"None"`` sentinel are compared by plain
    equality instead, and a mismatch is counted as null.

    Args:
        gold_answers: Reference solution texts.
        pred_answers: Model responses, index-aligned with ``gold_answers``.
        file_dir: Path of the prediction file; used in error messages and
            forwarded to ``extract_answer_pred``.
        test_dataset: Dataset name; selects the dataset-specific handling.
        model: Model name, forwarded to ``extract_answer_pred``.

    Returns:
        A tuple ``(accuracy, null, bracket_accuracies)``:
        ``accuracy`` is correct / ``len(gold_answers)`` (``0.0`` for empty
        input); ``null`` counts predictions with no extractable answer plus
        sentinel mismatches; ``bracket_accuracies`` maps each newline bucket
        (0 = not bucketed, 1 = at most three newlines, 2 = more) to its
        accuracy.
    """
    # Datasets whose predictions are bucketed by newline count.
    bucketed_datasets = ("gsm8k", "prm800k", "gsm8k-1000", "prm800k-1000")
    # Datasets whose correct predictions are credited to their bucket.
    # NOTE(review): this list differs from the bucketing list (the "-1000"
    # variants are bucketed but never credited; "mgsm" is credited but never
    # bucketed) - confirm whether that is intended.
    credited_datasets = ("gsm8k", "prm800k", "mgsm")

    correct = 0
    null = 0
    bracket_class = 0
    bracket_stats = defaultdict(lambda: {'correct': 0, 'total': 0})

    for i, data in enumerate(gold_answers):
        if test_dataset in bucketed_datasets:
            bracket_class = 1 if pred_answers[i].count('\n') <= 3 else 2
        bracket_stats[bracket_class]['total'] += 1

        gold_answer = extract_answer_gold(data)

        try:
            pred_answer = extract_answer_pred(pred_answers[i], file_dir, i, model)
        except Exception:
            # Best effort: report the offending prediction and move on.
            # Exception (not a bare except) keeps Ctrl-C working.
            print(f"File {file_dir} extraction error!")
            print(pred_answers[i])
            continue

        if pred_answer:
            pred_answer = pred_answer.replace(',', '')

        if gold_answer is None:
            # No '####' marker in the reference: nothing to grade against,
            # so the example can only count as incorrect (or null).
            print(f"Warning: no gold answer found for example {i}")
        else:
            gold_answer = gold_answer.replace(',', '')

        # Both extractors return the bare sentinel "None" (the '####' prefix
        # is already stripped), so that is the value to compare against.
        if test_dataset == "gsm8k-plus-mini" and (pred_answer == "None" or gold_answer == "None"):
            if pred_answer == gold_answer:
                correct += 1
            else:
                null += 1
        elif gold_answer is not None and grader.grade_answer(pred_answer, gold_answer):
            correct += 1
            if test_dataset in credited_datasets:
                bracket_stats[bracket_class]['correct'] += 1
        elif pred_answer is None:
            null += 1

    bracket_accuracies = {
        group: (stats['correct'] / stats['total'] if stats['total'] > 0 else 0)
        for group, stats in bracket_stats.items()
    }

    total = len(gold_answers)
    accuracy = correct / total if total else 0.0

    return accuracy, null, bracket_accuracies

def main():
    """Score saved prediction files and print in-distribution accuracy per model.

    For every combination of train dataset, test dataset, model, embedding,
    ``k`` and method, the reference answers are loaded with ``get_dataset``
    and each expected result file under ``base_dir`` is read and graded with
    ``calculate_accuracy``. Missing files and files whose length differs
    from the dataset are reported and skipped.

    Accuracy, average newline count, null count and per-bucket accuracy are
    all collected, but only the ``'id'`` accuracies are printed: mean and
    sample standard deviation (as percentages) when there are several runs,
    the single value when there is one, and ``None`` when there are none.
    """
    models = ['llama-3.1-8b-instruction', 'gemma-2-9b-it', 'mistral-7b-Instruct-v0.3',]
    embs = ["all-roberta-large-v1"]
    ks = [0]

    methods = ['knn']
    base_dir =  './results'
    test_datasets = ["gsm8k-plus-mini"]
    name = {
        "knn": "topk",
        "diversity": "div",
        "knn_diversity": "topk-div",
        "random": "rand",
        "k_means": "K-Means"
    }
    train_datasets = test_datasets
    seed = 1 if ks == [0] else 10

    # Layout: [model][result_type][method][k] -> list of per-run values.
    accuracy_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
    newline_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
    null_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
    # Same layout with one extra level for the newline bucket.
    bracket_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list)))))

    param_combinations = itertools.product(
        train_datasets, test_datasets, models, embs, ks, methods
    )

    for train_dataset, test_dataset, model, emb, k, method in param_combinations:
        is_knn_or_kmeans = 'knn' in method or 'k_means' in method
        _seed = 1 if is_knn_or_kmeans else seed
        # Grouped explicitly: `and` binds tighter than `or`, so without the
        # grouping the k > 1 test would apply to k_means only and every knn
        # run would sweep 10 permutation indices even for k <= 1.
        permutation = 10 if is_knn_or_kmeans and k > 1 else 1
        _, _, _, gold_answers = get_dataset(dataset=test_dataset, load_from_local=True)
        result_type = 'id' if train_dataset == test_dataset else 'ood'

        for i in range(_seed):
            for perm in range(permutation):
                file_dir = f"{base_dir}/{method}/{model}/{test_dataset}/{train_dataset}/{k}/{perm}/{i}/{emb}.jsonl"

                if not os.path.exists(file_dir):
                    print(f"File {file_dir} does not exist")
                    continue

                pred_answers = read_jsonl(file_dir=file_dir, key="answer")

                if len(pred_answers) != len(gold_answers):
                    print(f"Warning: The number of predicted answers in file {file_dir} does not match the dataset size")
                    continue

                total_newlines = sum(count_newlines(answer) for answer in pred_answers)
                avg_newlines = total_newlines / len(pred_answers)
                accuracy, null_count, bracket_accuracies = calculate_accuracy(gold_answers, pred_answers, file_dir, test_dataset, model)

                accuracy_dict[model][result_type][method][k].append(accuracy)
                newline_dict[model][result_type][method][k].append(avg_newlines)
                null_dict[model][result_type][method][k].append(null_count)

                if result_type == 'id':
                    for bracket_class, bracket_accuracy in bracket_accuracies.items():
                        bracket_dict[model]['id'][method][k][bracket_class].append(bracket_accuracy)

    for model in models:
        print(f"Model: {model}")
        print('-'*100)
        print("id:")
        if accuracy_dict[model]['id']:
            for method in methods:
                for k in ks:
                    runs = accuracy_dict[model]['id'][method][k]
                    if len(runs) > 1:
                        mean_pct = 100*sum(runs) / len(runs)
                        std_pct = 100*statistics.stdev(runs)
                        # Doubled braces emit literal '{' '}' around the std.
                        print(f"Method: {name[method]}, k: {k}, Average Accuracy: ${mean_pct:.2f}_{{{std_pct:.2f}}}$")
                    elif len(runs) == 1:
                        print(f"Method: {name[method]}, k: {k}, Average Accuracy: ${100*runs[0]:.2f}$")
                    else:
                        print(f"Method: {name[method]}, k: {k}, Average Accuracy: None")

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()